# embedding.py
import torch
from transformers import AutoModelForCausalLM

def compute_embeddings(texts, device, tokenizer, model_path):
    """Run a causal LM over ``texts`` and return its hidden states.

    The model is loaded from ``model_path`` on every call, moved to
    ``device``, switched to eval mode, and run once without gradient
    tracking.

    Args:
        texts: Input handed straight to ``tokenizer`` (presumably a string
            or a list of strings -- confirm against callers).
        device: Device the model and the tokenized inputs are moved to.
        tokenizer: Tokenizer used to encode ``texts``. If it has no pad
            token, its ``pad_token`` is set to its ``eos_token`` so that
            ``padding=True`` works; note this mutates the caller's
            tokenizer. An already-configured pad token is left untouched.
        model_path: Name or path accepted by
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        A tuple ``(hidden_states, input_ids, lm_head_weight)``:
        ``hidden_states`` is ``outputs.hidden_states`` from the forward
        pass, ``input_ids`` is the padded token-id tensor on ``device``,
        and ``lm_head_weight`` is ``model.lm_head.weight``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )
    # NOTE: passing device_map="auto" to from_pretrained is an alternative to
    # the explicit .to(device) below.
    model.to(device)
    model.eval()

    # Only fall back to EOS when the tokenizer has no pad token of its own;
    # overwriting an existing pad token would silently change how padded
    # positions are encoded for tokenizers that already define one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    return outputs.hidden_states, inputs["input_ids"], model.lm_head.weight